package idgenerator

import (
	"database/sql"
	"fmt"
	"gitee.com/glsoft/go-id-generator/util"
	log "github.com/sirupsen/logrus"
	"io"
	"net"
	"os"
	"os/signal"
	"runtime"
	"sync"
	"sync/atomic"
	"syscall"
)

// Service is the lifecycle contract of a server component.
type Service interface {
	// Initialize prepares the component's resources.
	Initialize() error
	// Start begins serving.
	Start() error
	// Stop halts the component.
	Stop() error
}

// connCount is the number of client connections currently being served.
// It is only accessed through sync/atomic.
var connCount int64

// IdGenService is a TCP server that hands out IDs, backed by a SQL database
// and keeping one IdGenerator per recorded key.
type IdGenService struct {
	// Embedded lock. NOTE(review): presumably guards keyGeneratorMap for the
	// request handlers defined elsewhere — confirm.
	sync.RWMutex

	// Backing resources, set up by Initialize.
	db       *sql.DB
	listener net.Listener

	// keyGeneratorMap holds one generator per key; Start fills it from the
	// key record table.
	keyGeneratorMap map[string]IdGenerator

	// Lifecycle state.
	running bool           // set by Start, cleared by Stop; gates the accept loop
	ch      chan os.Signal // SIGQUIT channel (real signal or sent by shutdown)

	// Configuration captured by NewService.
	dataSource string // DSN handed to sql.Open("mysql", ...)
	addr       string // TCP listen address
	password   string // hard-coded to "admin" by NewService
	runSync    bool   // when true, Start blocks waiting for SIGQUIT
}

// NewService creates an IdGenService for the given data source name and
// listen address, then runs Initialize on it.
//
// An empty addr falls back to "localhost:5678". The service password is
// hard-coded to "admin". runSync controls whether Start blocks waiting for
// SIGQUIT. The service value is returned even when Initialize fails, together
// with that error.
func NewService(dataSource string, addr string, runSync bool) (Service, error) {
	if addr == "" {
		addr = "localhost:5678"
	}
	svc := &IdGenService{
		dataSource: dataSource,
		addr:       addr,
		password:   "admin",
		runSync:    runSync,
	}
	return svc, svc.Initialize()
}

// Initialize prepares the service for Start.
//
// It resets the key→generator map and the signal channel. When no database
// handle is set yet, it opens one with the "mysql" driver and pings it. When
// no listener is set yet, it binds a TCP listener on s.addr.
//
// If the ping fails, the freshly opened handle is closed so it does not leak.
func (s *IdGenService) Initialize() error {
	s.keyGeneratorMap = make(map[string]IdGenerator)
	// signal.Notify never blocks when delivering, so the channel must be
	// buffered or a signal arriving while nobody is receiving is lost. The
	// slot also lets shutdown() enqueue a request when no reader is waiting.
	s.ch = make(chan os.Signal, 1)

	if s.db == nil {
		db, err := sql.Open("mysql", s.dataSource)
		if err != nil {
			return err
		}
		if err := db.Ping(); err != nil {
			// Best effort: the ping error is the one worth reporting.
			_ = db.Close()
			return err
		}
		s.db = db
	}
	if s.listener == nil {
		ln, err := net.Listen("tcp", s.addr)
		if err != nil {
			return err
		}
		s.listener = ln
	}
	return nil
}

// Start begins serving.
//
// It does the following, in order:
//   - Runs the CreateRecordTableNTSQLFormat statement, presumably a CREATE
//     TABLE IF NOT EXISTS.
//   - Builds an IdGenerator for every non-empty key stored in the key record
//     table that is not already in keyGeneratorMap.
//   - Accepts TCP connections in a background goroutine, serving each with
//     s.run.
//
// SIGQUIT is registered on s.ch. When s.runSync is set, Start then blocks
// until SIGQUIT arrives, closes s.db and exits the process with status 0.
// Otherwise it returns nil once the accept loop is running.
func (s *IdGenService) Start() error {
	createTableNtSQL := fmt.Sprintf(CreateRecordTableNTSQLFormat, KeyRecordTableName)
	if _, err := s.db.Exec(createTableNtSQL); err != nil {
		return err
	}

	// Read all keys and release the result set before building generators:
	// NewIdGenerator is handed s.db (presumably to query it), and it should
	// not compete with a still-open Rows for a pooled connection.
	keys, err := func() ([]string, error) {
		selectKeysSQL := fmt.Sprintf(SelectKeysSQLFormat, KeyRecordTableName)
		rows, err := s.db.Query(selectKeysSQL)
		if err != nil {
			return nil, err
		}
		defer util.CloseRows(rows)

		var found []string
		for rows.Next() {
			var key string
			if err := rows.Scan(&key); err != nil {
				return nil, err
			}
			if key != "" {
				found = append(found, key)
			}
		}
		// Next returning false can hide an iteration error; surface it.
		return found, rows.Err()
	}()
	if err != nil {
		return err
	}

	for _, key := range keys {
		if _, ok := s.keyGeneratorMap[key]; ok {
			continue
		}
		idGen, err := NewIdGenerator(s.db, key)
		if err != nil {
			return err
		}
		s.keyGeneratorMap[key] = idGen
	}

	s.running = true
	go func() {
		for s.running {
			conn, err := s.listener.Accept()
			if err != nil {
				if !s.running {
					// Stop closed the listener; this error is expected.
					return
				}
				log.Error(err)
				continue
			}
			go s.run(conn)
		}
	}()

	addr := s.listener.Addr().String()
	log.Infof("id generator running %s ", addr)

	// 注册要处理的信号 — register the signals to handle.
	signal.Notify(s.ch, syscall.SIGQUIT)
	if !s.runSync {
		return nil
	}
	for sign := range s.ch {
		if sign == syscall.SIGQUIT {
			_ = s.db.Close()
			// The controlling terminal/process asked us to quit.
			os.Exit(0)
		}
	}
	return nil
}

// Stop marks the service as no longer running and closes the listener, if
// there is one, so the accept loop started by Start ends. The database handle
// is left untouched. It always returns nil.
func (s *IdGenService) Stop() error {
	s.running = false
	if ln := s.listener; ln != nil {
		_ = ln.Close() // close error is deliberately ignored
	}
	log.Infof("id generator service stop!")
	return nil
}

// shutdown asks the signal loop in Start to terminate the process by sending a
// synthetic SIGQUIT on s.ch.
//
// The send is non-blocking. If no receiver is waiting and the channel has no
// free buffer slot, the request is dropped. That happens, for example, when
// Start ran with runSync == false, so nothing ever drains s.ch. Dropping it
// keeps the calling connection goroutine from hanging forever.
//
// Only SIGQUIT is registered on s.ch, so a full buffer already holds a pending
// quit request.
func (s *IdGenService) shutdown() {
	select {
	case s.ch <- syscall.SIGQUIT:
	default:
		log.Warn("id generator shutdown request dropped: no signal receiver")
	}
}

// request routes a parsed request to the handler for its command and returns
// that handler's reply. Unknown commands yield ErrMethodNotSupported.
func (s *IdGenService) request(req *Request) Reply {
	var reply Reply
	switch cmd := req.Command; cmd {
	case "PING":
		reply = s.handlePing(req)
	case "AUTH":
		reply = s.handleAuth(req)
	case "GET":
		reply = s.handleGet(req)
	case "SET":
		reply = s.handleSet(req)
	case "INCR":
		reply = s.handleIncr(req)
	case "EXISTS":
		reply = s.handleExists(req)
	case "DEL":
		reply = s.handleDel(req)
	case "SELECT":
		reply = s.handleSelect(req)
	case "QUIT":
		reply = s.handleQuit(req)
	case "SHUTDOWN":
		reply = s.handleShutdown(req)
	default:
		reply = ErrMethodNotSupported
	}
	return reply
}

// run serves a single client connection. It stops when:
//   - the client disconnects (io.EOF),
//   - a QUIT or SHUTDOWN command has been answered, or
//   - a reply can no longer be written.
//
// While it runs, the package-level connCount includes this connection. A
// panic raised while serving is recovered and logged with the current
// goroutine's stack, then reported to the client as an ErrorReply. After the
// connection is closed, if the service is no longer running and this was the
// last open connection, run calls Stop and then shutdown.
func (s *IdGenService) run(conn net.Conn) {
	atomic.AddInt64(&connCount, 1)
	defer func() {
		atomic.AddInt64(&connCount, -1)
		clientAddr := conn.RemoteAddr().String()
		if r := recover(); r != nil {
			err, ok := r.(error)
			if !ok {
				// Panics with non-error values (e.g. panic("...")) used to be
				// swallowed without a trace; report them too.
				err = fmt.Errorf("%v", r)
			}
			buf := make([]byte, 4096)
			buf = buf[:runtime.Stack(buf, false)] // stack of this goroutine only
			log.Errorf("remote %s error %s", clientAddr, err.Error())
			log.Errorf("%s", buf)
			reply := &ErrorReply{
				message: err.Error(),
			}
			// Best effort: the connection is being closed anyway.
			_, _ = reply.WriteTo(conn)
		}
		log.Infof("%s conn closed", clientAddr)
		_ = conn.Close()
		pv := atomic.LoadInt64(&connCount)
		log.Infof("conn %d", pv)
		if !s.running && pv == 0 {
			_ = s.Stop()
			s.shutdown()
		}
	}()

	for {
		req, err := NewRequest(conn)
		if err != nil {
			if err == io.EOF {
				return // client closed the connection
			}
			reply := &ErrorReply{
				message: err.Error(),
			}
			if _, werr := reply.WriteTo(conn); werr != nil {
				// The connection itself is most likely broken (reset, closed).
				// Reading again would keep failing, so stop instead of spinning.
				log.Errorf("reply remote %s write error %s", conn.RemoteAddr(), werr.Error())
				return
			}
			continue
		}

		reply := s.request(req)
		if reply == nil {
			continue
		}
		if _, err := reply.WriteTo(conn); err != nil {
			// The client can no longer be served; end the session.
			log.Errorf("reply remote %s write error %s", conn.RemoteAddr(), err.Error())
			return
		}
		if req.Command == "QUIT" || req.Command == "SHUTDOWN" {
			return
		}
	}
}

// getKey looks key up in the key record table and returns the stored key name.
//
// It returns an error when the query fails, a row cannot be scanned, iteration
// fails, or no non-empty name is found. The last case returns
// "<key>:not exists key". If several rows match, the last one scanned wins.
//
// NOTE(security): key is interpolated into SQL with fmt.Sprintf. If key comes
// from network clients, this is an SQL injection vector. SelectKeySQLFormat
// should become a placeholder query, with key passed as a Query argument.
func (s *IdGenService) getKey(key string) (string, error) {
	selectKeySQL := fmt.Sprintf(SelectKeySQLFormat, KeyRecordTableName, key)
	rows, err := s.db.Query(selectKeySQL)
	if err != nil {
		return "", err
	}
	defer util.CloseRows(rows)

	keyName := ""
	for rows.Next() {
		if err := rows.Scan(&keyName); err != nil {
			return keyName, err
		}
	}
	// Next returning false can also mean the iteration failed part-way.
	if err := rows.Err(); err != nil {
		return keyName, err
	}
	if keyName == "" {
		return keyName, fmt.Errorf("%s:not exists key", key)
	}
	return keyName, nil
}

// setKey inserts key with id into the key record table unless getKey already
// finds it, in which case it does nothing. An empty key is rejected.
//
// Any getKey failure, not only "not exists", leads to the insert attempt.
// NOTE(security): key is formatted into the SQL text; see getKey.
func (s *IdGenService) setKey(key string, id int64) error {
	if key == "" {
		return fmt.Errorf("%s:invalid key", key)
	}
	if _, lookupErr := s.getKey(key); lookupErr == nil {
		return nil // already recorded
	}
	stmt := fmt.Sprintf(InsertKeySQLFormat, KeyRecordTableName, key, id)
	_, err := s.db.Exec(stmt)
	return err
}

// delKey removes key from the key record table when getKey finds it. If the
// lookup fails for any reason, it does nothing and returns nil. An empty key
// is rejected.
//
// NOTE(security): key is formatted into the SQL text; see getKey.
func (s *IdGenService) delKey(key string) error {
	if key == "" {
		return fmt.Errorf("%s:invalid key", key)
	}
	if _, lookupErr := s.getKey(key); lookupErr != nil {
		return nil // nothing to delete (or the lookup failed)
	}
	stmt := fmt.Sprintf(DeleteKeySQLFormat, KeyRecordTableName, key)
	_, err := s.db.Exec(stmt)
	return err
}
